{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### imports\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch, math\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils import data\n",
    "from torch import optim\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "\n",
    "from transformers import get_linear_schedule_with_warmup,get_cosine_with_hard_restarts_schedule_with_warmup\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "from gensim import corpora, similarities, models\n",
    "from gensim.matutils import corpus2csc\n",
    "\n",
    "import pandas as pd \n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import base64\n",
    "from tqdm.auto import tqdm\n",
    "import pickle, random\n",
    "import matplotlib.pyplot as plt\n",
    "from multiprocessing import Pool,shared_memory\n",
    "import sys, csv, json, os, gc, time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_img_len = 30"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load data\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "trash = {'!', '$', \"'ll\", \"'s\", ',', '&', ':', 'and', 'cut', 'is', 'are', 'was'}\n",
    "trash_replace = ['\"hey siri, play some', 'however, ', 'yin and yang, ',\n",
    "                 'shopping mall/']\n",
    "\n",
    "def process(x):\n",
    "    tmp = x.split()\n",
    "    if tmp[0] in trash: x = ' '.join(tmp[1:])\n",
    "    if tmp[0][0] == '-': x = x[1:]\n",
    "    for tr in trash_replace:\n",
    "        x = x.replace(tr, '')\n",
    "    return x\n",
    "\n",
    "def normalize(x):\n",
    "    ret = x['boxes'].copy()\n",
    "    ret[:,0] /= x['image_h']\n",
    "    ret[:,1] /= x['image_w']\n",
    "    ret[:,2] /= x['image_h']\n",
    "    ret[:,3] /= x['image_w']\n",
    "    wh = (ret[:,2]-ret[:,0]) * (ret[:,3]-ret[:,1])\n",
    "    wh2 = (ret[:,2]-ret[:,0]) / (ret[:,3]-ret[:,1]+1e-6)\n",
    "    ret = np.hstack((ret, wh.reshape(-1,1), wh2.reshape(-1,1)))\n",
    "    return ret\n",
    "\n",
    "def load_data(file_name, reset=False, decode=True):\n",
    "    ret = pd.read_csv(file_name, sep='\\t')\n",
    "    if decode:\n",
    "        ret['boxes'] = ret['boxes'].apply(lambda x: np.frombuffer(base64.b64decode(x), dtype=np.float32).reshape(-1, 4))\n",
    "        ret['features'] = ret['features'].apply(lambda x: np.frombuffer(base64.b64decode(x), dtype=np.float32).reshape(-1, 2048))\n",
    "        ret['class_labels'] = ret['class_labels'].apply(lambda x: np.frombuffer(base64.b64decode(x), dtype=np.int64).reshape(-1, 1))\n",
    "        ret['boxes'] = ret.apply(lambda x: normalize(x), axis=1)\n",
    "        ret['features'] = ret.apply(lambda x: np.concatenate((x['class_labels'], x['features'], x['boxes']), axis=1)[:max_img_len], axis=1)\n",
    "    ret['query'] = ret['query'].apply(lambda x: process(x))\n",
    "    # reset query_id\n",
    "    if reset:\n",
    "        query2qid = {query: qid for qid, (query, _) in enumerate(tqdm(ret.groupby('query')))}\n",
    "        ret['query_id'] = ret.apply(lambda x: query2qid[x['query']], axis=1)\n",
    "    return ret[['features']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shared(i):\n",
    "    name='train_{}'.format(i)\n",
    "    print(name)\n",
    "    df=load_data(path+'data/train_{}.tsv'.format(i))\n",
    "#     df=load_data(path+'train.sample.tsv'.format(i))\n",
    "    a = np.vstack(df['features'].values)\n",
    "    shm = shared_memory.SharedMemory(create=True, size=a.nbytes,name=name)\n",
    "    b = np.ndarray(a.shape, dtype=a.dtype, buffer=shm.buf)\n",
    "    b[:] = a[:] \n",
    "    shape_list=[0]\n",
    "    for _,row in tqdm(df['features'].iteritems()):\n",
    "        shape_list.append(shape_list[-1]+row.shape[0])\n",
    "    file = open('shape_list_{}'.format(i), 'wb')\n",
    "    pickle.dump(shape_list, file)\n",
    "    file.close() \n",
    "    return shm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_0\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56ce74eefe254b8e994c510a8c8f030c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train_1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7b4e11e124b4c40b160fb935f101346",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train_2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48a3f038ee1d4dddb3413b92928e2b9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train_3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7131521232b46d287ebd1ee909626fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train_4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86c4566797e144a59ccad7f2d73b0ec8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "ZeroDivisionError",
     "evalue": "division by zero",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mZeroDivisionError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-0dfc50d50cff>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mlst\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshared\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;36m1\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mZeroDivisionError\u001b[0m: division by zero"
     ]
    }
   ],
   "source": [
    "lst=[]\n",
    "for i in range(20):\n",
    "    lst.append(shared(i))\n",
    "1/0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(20):\n",
    "    try:\n",
    "        shm = shared_memory.SharedMemory(name='train_{}'.format(i))\n",
    "        shm.close()\n",
    "        shm.unlink()\n",
    "    except:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (myenv)",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
